[aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 #17031
Conversation
I'd appreciate it if someone could review this PR.
Force-pushed from be947da to 2aef2f3.
Hi @snnn, would you be able to review and provide feedback on this PR? I appreciate your time.
Force-pushed from 2aef2f3 to 9b51325.
Hi, I have rebased the PR to resolve the merge conflicts. I'm happy to address any feedback you may have. Thank you!
I have checked out the changes and run performance tests and accuracy tests with and without the flag using
Force-pushed from eb257ff to 83a6f6e.
Hi @chenfucn, @yufenglee, I have updated the PR (1) to move to the newer gemm interface and (2) to add session-option-based fastmath mode control. Please review and let me know your feedback.
Force-pushed from 83a6f6e to 2fffd44.
As we discussed, please add mlas unit tests that call the kernel directly with different shapes and other parameters.
/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, Linux QNN CI Pipeline, MacOS CI Pipeline
/azp run ONNX Runtime Web CI Pipeline, Windows ARM64 QNN CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed
Azure Pipelines successfully started running 7 pipeline(s).
Azure Pipelines successfully started running 9 pipeline(s).
Thanks for the review. I will update the PR to address this and also add unit tests.
Force-pushed from 2fffd44 to cef62df.
I have updated the PR to address all the feedback so far, along with learnings from my other qgemm PR. Next, I will add ORT optimizer and provider tests to test the fastmath session option.
@snadampal did you push your change to GitHub?
Not yet, I'm planning to push the code format changes along with the session name change.
Hi @skottmckay, I'd appreciate your response on this.
I think I would consider the first part of the name as something that points me to where I would find the setting being used; e.g. 'optimization' means look in the optimizer project. I would say it's inferred that you're configuring something in the session, as you're using SessionOptions (vs., say, RunOptions). Based on that, I would vote for 'mlas.' as the prefix. The name also seems a little too generic, as it sounds like it would apply to MLAS as a whole. Unless we think there will be some other fastpath that applies to MLAS GEMM in general, a more specific name would be clearer. Alternatively, the platform/datatype could be in the value and you could parse that.
Thank you, I see your point. bf16 and f16 are the potential fastmath options, but on aarch64 I so far see interest in bf16 fastmath alone. I agree that there may not be multiple of these for different platforms, so I will go ahead with a simple config key.
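Purely as an illustration of the two naming styles discussed above (a specific per-feature key versus a generic key whose value encodes the platform/datatype and is parsed), with all key and value strings hypothetical:

```python
# Hypothetical illustration only: all key and value strings below are made up.

# Style 1: a specific key per platform/datatype; the value is a simple on/off flag.
specific = {"mlas.enable_gemm_fastmath_arm64_bfloat16": "1"}

# Style 2: one generic key; the value encodes platform and datatype and is parsed.
generic = {"mlas.gemm_fastmath": "ARM64:bfloat16"}

def parse_generic(value):
    """Split a 'platform:datatype' value into its two parts."""
    platform, _, dtype = value.partition(":")
    return platform, dtype

print(parse_generic(generic["mlas.gemm_fastmath"]))   # ('ARM64', 'bfloat16')
```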
Added the SbgemmKernel assembly implementation with bfmmla instructions, plus sbgemm utility functions to prepack Matrix B along with conversion to bfloat16.
The sbgemm kernel is invoked when fastmath mode is enabled and the hardware supports the bf16 instruction set. It is disabled by default; set the following session option to 1 to enable it: "kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16".
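As a usage sketch (not part of the PR diff), enabling this from the Python API might look like the following; the string key assumed for the constant and the model path are placeholders, and the fastmath path is only taken on bf16-capable aarch64 hardware:

```python
# Sketch: enable the bfloat16 fastmath GEMM path through a session config entry.
# Assumption: kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 corresponds to the
# string key "mlas.enable_gemm_fastmath_arm64_bfloat16"; "model.onnx" is a placeholder.
import onnxruntime as ort

so = ort.SessionOptions()
so.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")

# The sbgemm kernel is only selected when the CPU reports bf16 support;
# otherwise the regular fp32 sgemm path is used.
sess = ort.InferenceSession("model.onnx", sess_options=so,
                            providers=["CPUExecutionProvider"])
```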
Force-pushed from f45ef1d to d6d48c3.
Updated the PR for the session name and other points discussed so far, including the clang formatting. Tested
/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, Linux QNN CI Pipeline, MacOS CI Pipeline, Windows ARM64 QNN CI Pipeline
/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows x64 QNN CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed
Azure Pipelines successfully started running 8 pipeline(s).
Azure Pipelines successfully started running 8 pipeline(s).
Thanks to @chenfucn, @snnn, @skottmckay and @yufenglee for the great feedback and merging the PR!
@snadampal, thanks for making ONNX Runtime better. You're welcome to bring more changes to us. You have my email; do not hesitate to contact me anytime you need help with reviewing PRs.
Description
This PR adds the SbgemmKernel for aarch64. This includes the Sbgemm kernel implementing matrix multiplication with bfloat16 SIMD instructions (bfmmla) and MatMul operator changes to invoke the Sbgemm kernel. To enable the Sbgemm kernel, set the following session option:
"kOrtSessionOptionsGemmFastMathMode"
The PR also adds new test cases for mlas and ort.
Motivation and Context
This is to improve MatMul performance on the aarch64 platform.
I ran the benchmarking script below (bert, roberta and gpt2 model inference) on an AWS Graviton3 based c7g.4xl instance and observed a 1.2x to 1.76x performance improvement compared to the sgemm (fp32) kernel performance.
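The benchmarking script was run as follows:

```
cd onnxruntime/python/tools/transformers
python3 benchmark.py
```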
The unit test precision results match the sgemm kernel results.
```
./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync
```
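As a rough sketch (not the PR's actual test code) of how the precision comparison against the fp32 path could be reproduced, run the same model with and without the fastmath option and compare the outputs; the config-key string and the model path are assumptions:

```python
# Sketch only: compare default fp32 sgemm output with bf16 fastmath output.
# Assumptions: the config-key string below and "model.onnx" are placeholders.
import numpy as np
import onnxruntime as ort

MODEL = "model.onnx"                                # placeholder model path
KEY = "mlas.enable_gemm_fastmath_arm64_bfloat16"    # assumed key string

def run(x, enable_fastmath):
    so = ort.SessionOptions()
    if enable_fastmath:
        so.add_session_config_entry(KEY, "1")
    sess = ort.InferenceSession(MODEL, sess_options=so,
                                providers=["CPUExecutionProvider"])
    name = sess.get_inputs()[0].name
    return sess.run(None, {name: x})[0]

x = np.random.rand(1, 128, 768).astype(np.float32)  # adjust to the model's input shape
ref = run(x, enable_fastmath=False)                 # regular fp32 sgemm path
fast = run(x, enable_fastmath=True)                 # bf16 fastmath path (bf16-capable aarch64 only)
# bfloat16 keeps only ~7 mantissa bits, so use a loose tolerance.
print(np.allclose(ref, fast, rtol=1e-2, atol=1e-3))
```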